#########################################################################################################.
###                                                                                                   ###
###     Rage Against the Machine? Generative AI Exposure, Subjective Risk, and Policy Preferences     ###
###     Journal of European Public Policy                                                             ###
###                                                                                                   ###
###     Haslberger, Gingrich & Bhatia                                                                 ###
###                                                                                                   ###
###     Helper Functions                                                                              ###
###                                                                                                   ###
#########################################################################################################.


# Global constants ----

# Colour scheme: a single palette; the one- and two-colour variants are
# subsets of it (blue, then blue/red) so the colours stay in sync.
threecolours <- c("#004488", "#DDAA33", "#BB5566")
twocolours   <- threecolours[c(1L, 3L)]
onecolour    <- threecolours[1L]

# Horizontal dodge used when offsetting estimates in figures
dodge_width <- 0.6

#################.
# FUNCTIONS ----
#################.

# Function to extract coefficient and confidence intervals from a model
#
# Finds the treatment coefficient of a fitted model (a coefficient whose name
# contains "treatment") and returns its point estimate together with its 95%
# and 90% confidence bounds.
#
# model : a fitted model supporting coef() and confint(parm =, level =),
#         e.g. an 'lm' object.
#
# Returns a length-5 named numeric vector:
#   estimate, 95% lower, 95% upper, 90% lower, 90% upper.
#
# Errors if no coefficient name contains "treatment". If several do (e.g. a
# main effect plus interaction terms), main effects (names without ":") are
# preferred, the first candidate is used and a warning is issued.
extract_model_info <- function(model) {
  coef_names <- names(coef(model))
  matches <- grep("treatment", coef_names, value = TRUE)

  if (length(matches) == 0) {
    stop("Could not find a coefficient for 'treatment'.")
  }

  # Prefer main effects over interaction terms such as "treatment:female";
  # without this, several matches produced a vector longer than 5.
  main_effects <- matches[!grepl(":", matches, fixed = TRUE)]
  candidates <- if (length(main_effects) > 0) main_effects else matches
  treatment_coef_name <- candidates[1]

  if (length(matches) > 1) {
    warning("Several coefficients match 'treatment'; using '",
            treatment_coef_name, "'.")
  }

  coef_val <- coef(model)[treatment_coef_name]
  # Only compute intervals for the parameter of interest.
  conf_int_95 <- confint(model, parm = treatment_coef_name, level = 0.95)[1, ]
  conf_int_90 <- confint(model, parm = treatment_coef_name, level = 0.90)[1, ]

  c(coef_val, conf_int_95[1], conf_int_95[2], conf_int_90[1], conf_int_90[2])
}

# Familiarity with tasks
#
# Dodged bar chart of response shares, one bar colour per group.
#
# df      : data frame to plot. aes() looks up `q` (x position) and `cat`
#           (fill group) as columns of `df` first; the `q`/`cat` arguments
#           are only reached if `df` lacks those columns. Assumes `q` is a
#           numeric code 1/2/3 -- TODO confirm against callers.
# y       : not used by the body.
# labs    : plot title.
# ylimits : y-axis limits (bars show after_stat(prop), printed as percent).
#
# Returns a ggplot object. Bars use the global `twocolours` palette.
plot_canvas_simtask <- function(df = data,
                                q = q,
                                cat = cat,
                                y = after_stat(prop),
                                labs = "label",
                                ylimits = c(0, .6)) {
  # Tick texts for the answer codes 1, 2 and 3.
  answer_text <- c("Frequently", "Sometimes", "Never")

  components <- list(
    geom_bar(mapping  = aes(x = q, y = after_stat(prop), fill = cat),
             stat     = "count",
             position = "dodge",
             width    = .5),
    xlab(""),
    ylab(""),
    labs(title = labs),
    # subtitle idea (currently disabled): 'frequency in <span style = "color:#004488">**control group**</span> and <span style = "color:#DDAA33">**treatment group**</span>'
    scale_x_continuous(guide  = guide_axis(n.dodge = 1),
                       labels = function(x) stringr::str_wrap(answer_text, width = 3),
                       breaks = c(1, 2, 3)),
    scale_y_continuous(limits = ylimits,
                       labels = scales::percent_format(accuracy = 1)),
    scale_color_manual(values = twocolours),
    scale_fill_manual(values = twocolours),
    theme_minimal(),
    theme(axis.text          = element_text(size = 18),
          axis.text.x        = element_text(angle = 45, hjust = 1),
          axis.title         = element_text(size = 18),
          title              = element_text(size = 20),
          panel.border       = element_blank(),
          panel.grid.major   = element_blank(),
          panel.grid.minor   = element_blank(),
          axis.line          = element_line(colour = "black"),
          plot.subtitle      = ggtext::element_markdown(size = 18, lineheight = 1.3),
          legend.position    = "none")
  )

  # Adding a list to a ggplot adds its elements one by one, in order.
  ggplot(data = df) + components
}

# Usefulness responses
#
# Dodged bar chart of response shares, one bar colour per group, with the
# x-axis codes 1/2/3 labelled "Not used" / "Not useful" / "Useful".
#
# df      : data frame to plot. aes() looks up `q` (x position) and `cat`
#           (fill group) as columns of `df` first; the `q`/`cat` arguments
#           are only reached if `df` lacks those columns. Assumes `q` is a
#           numeric code 1/2/3 -- TODO confirm against callers.
# y       : not used by the body.
# labs    : plot title.
# ylimits : y-axis limits (bars show after_stat(prop), printed as percent).
#
# Returns a ggplot object. Bars use the global `twocolours` palette.
plot_canvas_useful <- function(df = data, q = q, cat = cat, y = after_stat(prop), labs = "label", ylimits = c(0, 1)){
  
  ggplot(data=df) + 
    geom_bar(mapping = aes(x = q, y = after_stat(prop), fill = cat), 
             stat = "count", position = "dodge", width = .5) + 
    xlab("") + ylab("")+
    labs(title=labs) +
    scale_x_continuous(guide = guide_axis(n.dodge = 1), 
                       # Fixed texts for breaks 1, 2, 3 ("Useful" was misspelled "Uselful").
                       labels = function(x) stringr::str_wrap(c("Not used", 
                                                                "Not useful", 
                                                                "Useful"), width = 3),
                       breaks = c(1,2,3)) +
    scale_y_continuous(limits = ylimits,labels = scales::percent_format(accuracy = 1)) +
    scale_color_manual(values = twocolours) +
    scale_fill_manual(values = twocolours) +
    theme_minimal() +
    theme(axis.text=element_text(size=18),
          axis.text.x = element_text(angle = 45, hjust = 1),
          axis.title=element_text(size=18),
          title=element_text(size=20),
          panel.border = element_blank(), panel.grid.major = element_blank(),
          panel.grid.minor = element_blank(), axis.line = element_line(colour = "black"),
          plot.subtitle = ggtext::element_markdown(size = 18, lineheight = 1.3),
          legend.position = "none")
}




# Interaction figures ----

# Differences from the baseline cell in a 2x2 interaction model.
#
# model : a fitted model (e.g. 'lm' or 'ivreg') whose formula contains
#         outcome ~ Xvar * Zvar and which supports coef(), vcov() and
#         df.residual().
# Xvar  : quoted name of the first binary regressor as it appears in the
#         coefficient names: a factor with levels c("0", "1") (coefficient
#         "<Xvar>1", e.g. "tr1" or "factor(tr)1") or a numeric 0/1 variable
#         (coefficient "<Xvar>").
# Zvar  : quoted name of the second binary regressor, coded like Xvar.
#
# For the cells (X=1,Z=0), (X=0,Z=1) and (X=1,Z=1) the difference from the
# baseline (X=0,Z=0) is the sum of the relevant coefficients
# (bX, bZ, bX + bZ + bXZ). Its variance is the sum of the matching block of
# the coefficient covariance matrix. Intervals use t quantiles with the
# model's residual degrees of freedom.
#
# Returns a data.frame with one row per non-baseline cell and columns
#   X_level, Z_level, difference, std_error,
#   ci90_lower, ci90_upper, ci95_lower, ci95_upper.
# Errors if the intercept or either main effect is missing. Warns, and
# leaves out bXZ, if no X:Z interaction coefficient exists.
extract_marginal_effects <- function(model, Xvar, Zvar) {
  coefs      <- coef(model)
  coef_names <- names(coefs)
  vmat       <- vcov(model)
  resid_df   <- df.residual(model)

  # Two-sided critical values
  t95 <- qt(0.975, resid_df)  # 95%
  t90 <- qt(0.95,  resid_df)  # 90%

  # Exact-name lookup (no regex). Variable names with regex metacharacters
  # (e.g. "factor(tr)") therefore work, and higher-order terms such as
  # "X1:Z1:W1" are never picked up by accident. Candidates are tried in
  # order; returns character(0) when none is present.
  find_coef <- function(candidates) {
    head(intersect(candidates, coef_names), 1)
  }

  int_name <- "(Intercept)"
  # Factor coding ("X1") first, numeric 0/1 coding ("X") as fallback.
  X_name <- find_coef(c(paste0(Xvar, "1"), Xvar))
  Z_name <- find_coef(c(paste0(Zvar, "1"), Zvar))

  # Safety checks
  if (!int_name %in% coef_names) {
    stop("Intercept not found in coefficients.")
  }
  if (length(X_name) == 0) {
    stop("Couldn't find a main-effect coefficient for '", Xvar, "=1'. Check factor coding or naming.")
  }
  if (length(Z_name) == 0) {
    stop("Couldn't find a main-effect coefficient for '", Zvar, "=1'. Check factor coding or naming.")
  }

  # The interaction is labelled "X:Z" or "Z:X" depending on formula order.
  XZ_name <- find_coef(c(paste0(X_name, ":", Z_name),
                         paste0(Z_name, ":", X_name)))
  if (length(XZ_name) == 0) {
    warning("No interaction term found for X=1, Z=1. If your formula was ~ Xvar*Zvar, 
             you should see an interaction. Possibly naming mismatch.")
  }

  # Estimate and variance of a sum of coefficients:
  # Var(sum b) = sum of every entry of the covariance sub-matrix.
  get_diff_var <- function(terms) {
    list(est = sum(coefs[terms]),
         var = sum(vmat[terms, terms, drop = FALSE]))
  }

  combo_10 <- get_diff_var(X_name)                       # bX
  combo_01 <- get_diff_var(Z_name)                       # bZ
  # bX + bZ + bXZ; XZ_name is character(0) when absent, giving bX + bZ.
  combo_11 <- get_diff_var(c(X_name, Z_name, XZ_name))

  # One output row with 90% and 95% intervals
  build_row <- function(xl, zl, est, var_est) {
    se_est <- sqrt(var_est)
    data.frame(
      X_level     = xl,
      Z_level     = zl,
      difference  = est,
      std_error   = se_est,
      ci90_lower  = est - t90 * se_est,
      ci90_upper  = est + t90 * se_est,
      ci95_lower  = est - t95 * se_est,
      ci95_upper  = est + t95 * se_est
    )
  }

  out <- rbind(
    build_row("1", "0", combo_10$est, combo_10$var),
    build_row("0", "1", combo_01$est, combo_01$var),
    build_row("1", "1", combo_11$est, combo_11$var)
  )
  rownames(out) <- NULL
  out
}


# Function for title grobs to use in final figures
#
# title_text : title string; rendered bold and underlined via plotmath.
# font_size  : font size passed to grid::gpar(fontsize =).
#
# Returns a grid text grob. grid functions are namespaced so the helper also
# works when the grid package is not attached.
make_title_grob <- function(title_text = "My Title", font_size = 16) {
  # bquote() splices the user-supplied text into the plotmath expression
  # bold(underline(<title_text>)).
  title_expr <- bquote(bold(underline(.(title_text))))
  
  grid::textGrob(
    label = title_expr,
    gp = grid::gpar(fontsize = font_size)
  )
}



# Customizable function to get legend grob
#
# Builds a throw-away point plot whose only purpose is to produce a shape
# legend, then extracts that legend with cowplot::get_legend() so it can be
# placed in composite figures.
#
# shape_mapping : named vector of point shapes; the names are the legend keys
#                 and their order is the legend order.
# shape_labels  : optional labels for the keys; NULL uses the key names.
# size, stroke  : point size and stroke of the legend keys.
# legend_title  : passed as the scale name; NULL by default.
#
# Returns the legend grob. Points are drawn in the global `onecolour`.
make_legend_grob <- function(
    shape_mapping = c("Male, treated" = 15,
                      "Female, control" = 16,
                      "Female, treated" = 17),
    shape_labels  = NULL,
    size          = 5,
    stroke        = 1,
    legend_title  = NULL
) {
  keys <- names(shape_mapping)

  # Default labels: the keys themselves, named by key.
  if (is.null(shape_labels)) {
    shape_labels <- setNames(keys, keys)
  }

  # Dummy coordinates: only the legend survives, so positions are irrelevant.
  legend_data <- data.frame(
    group = factor(keys, levels = keys),
    x = 1,
    y = seq_along(shape_mapping)
  )

  legend_plot <- ggplot(legend_data, aes(x = x, y = y, shape = group)) +
    geom_point(size = size, stroke = stroke, color = onecolour) +
    scale_shape_manual(values = shape_mapping,
                       labels = shape_labels,
                       name   = legend_title) +
    theme_void() +
    theme(legend.position = "bottom",
          legend.text     = element_text(size = 14))

  cowplot::get_legend(legend_plot)
}